{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lhn6RdCkunJY"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nFvE5hKZuYy6"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cfjBLLxVb61A"
      },
      "source": [
        "# MOSAIC with Model Garden\n",
        "\n",
        "This tutorial demonstrates how to load a\n",
        "[MOSAIC](https://github.com/tensorflow/models/tree/master/official/projects/mosaic) model trained with the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official) library and run semantic segmentation inference with it.\n",
        "\n",
        "The TensorFlow Model Garden contains a collection of\n",
        "state-of-the-art models, implemented with TensorFlow's high-level APIs. The\n",
        "implementations demonstrate best practices for modeling, letting users\n",
        "take full advantage of TensorFlow for their research and product development."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hmslpkDEcPk9"
      },
      "source": [
        "## Install Necessary Dependencies"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WjAMpVVAcTYh"
      },
      "source": [
        "## Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T0JMx3hmcXL5"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.patches as mpatches\n",
        "\n",
        "from PIL import Image\n",
        "from io import BytesIO\n",
        "from IPython import display\n",
        "from urllib.request import urlopen\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "import absl.logging\n",
        "absl.logging.set_verbosity(absl.logging.ERROR)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "02j9_wFYd3H5"
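      },
      "source": [
        "## Load Pretrained Model\n",
        "This tutorial loads the model from a downloaded SavedModel. Alternatively, you can load the pretrained model from [TF Hub](https://tfhub.dev/google/mosaic/mobilenetmultiavgseg); in the snippet below, `IMAGE_HEIGHT` and `IMAGE_WIDTH` are placeholders for the input size you want to use:\n",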
        "\n",
        "```\n",
        "keras_layer = hub.KerasLayer(\n",
        "  'https://tfhub.dev/google/mosaic/mobilenetmultiavgseg/2',\n",
        "  signature='serving_default',\n",
        "  output_key='logits')\n",
        "model = tf.keras.Sequential([keras_layer])\n",
        "model.build([None, IMAGE_HEIGHT, IMAGE_WIDTH, 3])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7MOWG8h5eEh6"
      },
      "source": [
        "### Download SavedModel\n",
        "The model uses the implementation from the TensorFlow Model Garden GitHub repository and achieves 77.24% mIoU on the Cityscapes dataset (19 classes)."
      ]
    },
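    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For readers unfamiliar with the metric, here is a minimal sketch of how mean intersection-over-union (mIoU) is commonly computed from a confusion matrix. It is not the Model Garden evaluation code: the function name `mean_iou` and the toy label maps are assumptions made for illustration, and the sketch leaves out any masking of ignored pixels.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def mean_iou(y_true, y_pred, num_classes):\n",
        "  # Hypothetical helper: mean IoU over classes present in either label map.\n",
        "  y_true = np.asarray(y_true).ravel()\n",
        "  y_pred = np.asarray(y_pred).ravel()\n",
        "  # Confusion matrix: rows are ground-truth classes, columns are predictions.\n",
        "  confusion = np.bincount(\n",
        "      y_true * num_classes + y_pred,\n",
        "      minlength=num_classes * num_classes).reshape(num_classes, num_classes)\n",
        "  intersection = np.diag(confusion)\n",
        "  union = confusion.sum(axis=0) + confusion.sum(axis=1) - intersection\n",
        "  present = union > 0  # Skip classes absent from both label maps.\n",
        "  return float(np.mean(intersection[present] / union[present]))\n",
        "\n",
        "\n",
        "# Toy 2x3 label maps with three classes (made-up data).\n",
        "y_true = np.array([[0, 0, 1], [1, 2, 2]])\n",
        "y_pred = np.array([[0, 1, 1], [1, 2, 0]])\n",
        "print(mean_iou(y_true, y_pred, num_classes=3))\n",
        "```\n",
        "\n",
        "In this toy case the per-class IoU values are 1/3, 2/3, and 1/2, so the mean is about 0.5."
      ]
    },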
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WvTrNdt7sjh4"
      },
      "outputs": [],
      "source": [
        "! curl https://storage.googleapis.com/tf_model_garden/vision/mosaic/mosaic_mobilenet_multiavgseg_r1024_ebf64_gp_model.tar.gz --output model.tar.gz"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VEJf6ZTcsREl"
      },
      "outputs": [],
      "source": [
        "# Extract the SavedModel from the downloaded archive.\n",
        "! tar -xvf model.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xwD9Rd5EebNC"
      },
      "source": [
        "### Load the SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SvOypbg0dDOa"
      },
      "outputs": [],
      "source": [
        "# Load saved model\n",
        "export_dir = \"saved_model\"\n",
        "imported = tf.saved_model.load(export_dir)\n",
        "model_fn = imported.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Th_BzrD5eoPp"
      },
      "source": [
        "## Run Inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X8ACyzMDdDR7"
      },
      "outputs": [],
      "source": [
        "# Helper function to load an image from a URL or a local file path.\n",
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Loads an image into a batched numpy array.\n",
        "\n",
        "  The image is converted to RGB and given a leading batch dimension so that\n",
        "  it can be fed directly to the model. By convention the layout is\n",
        "  (batch, height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Args:\n",
        "    path: URL (starting with 'http') or file path of the image.\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (1, img_height, img_width, 3).\n",
        "  \"\"\"\n",
        "  if path.startswith('http'):\n",
        "    with urlopen(path) as response:\n",
        "      image_data = response.read()\n",
        "  else:\n",
        "    with tf.io.gfile.GFile(path, 'rb') as f:\n",
        "      image_data = f.read()\n",
        "\n",
        "  # Force three channels so grayscale or RGBA inputs keep the expected shape.\n",
        "  image = Image.open(BytesIO(image_data)).convert('RGB')\n",
        "\n",
        "  # A PIL image converts to an array of shape (height, width, channels).\n",
        "  image = np.asarray(image, dtype=np.uint8)\n",
        "\n",
        "  # Add the batch dimension expected by the model.\n",
        "  return image[np.newaxis, ...]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n3BXyYune5p6"
      },
      "source": [
        "### Download a Sample Image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D7KLJpIGdDVs"
      },
      "outputs": [],
      "source": [
        "image_path = \"https://storage.googleapis.com/tf_model_garden/vision/mosaic/cityscape_sample.png\"\n",
        "image_array = load_image_into_numpy_array(image_path)\n",
        "image_array.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u2eMxCEfe83m"
      },
      "outputs": [],
      "source": [
        "Image.fromarray(image_array[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gn99qmZAf75S"
      },
      "source": [
        "### Run Inference on Sample Image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gDdehkLeNeqz"
      },
      "outputs": [],
      "source": [
        "outputs = model_fn(inputs=image_array)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LCWNjuFxNvTZ"
      },
      "outputs": [],
      "source": [
        "outputs['logits'].shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l9WHT3WsgePX"
      },
      "source": [
        "The output has shape `[batch_size, height, width, num_classes]` and contains the raw, unnormalized class scores (`logits`) for each pixel. Taking the argmax over the last axis gives the predicted class index for each pixel."
      ]
    },
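    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the shape convention concrete, the sketch below converts a small, made-up logits array into a per-pixel class map, and optionally into softmax probabilities. The array values and the names `class_map`, `probs`, and `confidence` are assumptions for illustration; they are not outputs of the MOSAIC model.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Made-up logits for a batch of one 2x2 image with 3 classes.\n",
        "logits = np.array([[[[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]],\n",
        "                    [[-0.5, 1.5, 0.0], [4.0, 0.0, 0.0]]]])\n",
        "print(logits.shape)  # [batch_size, height, width, num_classes]\n",
        "\n",
        "# Predicted class per pixel: index of the largest logit along the class axis.\n",
        "class_map = np.argmax(logits, axis=-1)\n",
        "\n",
        "# Optional: a numerically stable softmax turns logits into probabilities.\n",
        "shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)\n",
        "\n",
        "# Probability of the winning class, one value per pixel.\n",
        "confidence = probs.max(axis=-1)\n",
        "\n",
        "print(class_map[0])\n",
        "print(confidence.shape)\n",
        "```\n",
        "\n",
        "Softmax does not change which class has the highest score, so the argmax of the probabilities matches the argmax of the logits."
      ]
    },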
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wGusC66mvFw9"
      },
      "outputs": [],
      "source": [
        "# Pick the highest-scoring class for every pixel of the first image in the batch.\n",
        "detection_results = np.argmax(outputs['logits'][0], axis=-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TpzoZOJ_u1dw"
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots(figsize=(12, 12))\n",
        "ax.imshow(detection_results)"
      ]
    },
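    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The plot above uses the default colormap, which assigns colors by numeric class index. One possible way to get fixed, per-class colors is to index a palette array with the class map, as sketched below. The three-row `palette` and the tiny `class_map` are assumptions for illustration; a real palette needs one RGB row per class in the model's label set.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical 3-class RGB palette; a real one needs a row per class.\n",
        "palette = np.array([[128, 64, 128],   # class 0\n",
        "                    [70, 70, 70],     # class 1\n",
        "                    [0, 0, 142]],     # class 2\n",
        "                   dtype=np.uint8)\n",
        "\n",
        "# Made-up class map standing in for the argmax result.\n",
        "class_map = np.array([[0, 0, 2],\n",
        "                      [1, 2, 2]])\n",
        "\n",
        "# Indexing the palette with the class map yields a (height, width, 3) image.\n",
        "color_mask = palette[class_map]\n",
        "print(color_mask.shape)  # (2, 3, 3)\n",
        "```\n",
        "\n",
        "The resulting array can be passed to `ax.imshow`. Building one `mpatches.Patch(color=..., label=...)` per class and passing the list to `ax.legend(handles=...)` is one way to add a legend; Matplotlib expects RGB tuples in the 0-1 range, so the palette values would need to be divided by 255."
      ]
    },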
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "05ZTuuAci3Bu"
      },
      "source": [
        "## Model Training\n",
        "\n",
        "To train MOSAIC, see the [TF Model Garden MOSAIC implementation](https://github.com/tensorflow/models/tree/master/official/projects/mosaic).\n",
        "\n",
        "For an example of fine-tuning models from the TensorFlow Model Garden package, see the [Image Classification Tutorial](https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1G4xIYr_wrx-QXzboQLNT2DGAfA4jNYLQ",
          "timestamp": 1671845366432
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
